# -*- coding: utf-8 -*- #
"""*********************************************************************************************"""
#   FileName     [ timit2ark.py ]
#   Synopsis     [ write our own preprocessed TIMIT features into .ark files for Kaldi ]
#   Author       [ Andy T. Liu (Andi611) ]
#   Copyright    [ Copyleft(c), Speech Lab, NTU, Taiwan ]
#   Reference    [ https://github.com/mravanelli/pytorch-kaldi#how-can-i-use-my-own-dataset ]
"""*********************************************************************************************"""

###############
# IMPORTATION #
###############
import os
import pickle
import sys

import kaldi_io
import numpy as np
from tqdm import tqdm


############
# SETTINGS #
############
# Root of your local Kaldi checkout.
# !!!!!!!!!!!!!!!!! CHANGE THIS TO YOUR OWN KALDI ROOT !!!!!!!!!!!!!!!!! #
KALDI_ROOT = '/media/andi611/1TBSSD/kaldi/'

# Preprocessed TIMIT features; generate them with 'preprocess_timit.py'.
INPUT_PATH = '../data/timit_mel160_phoneme63'

# Splits read from INPUT_PATH, and splits written for Kaldi.
# (You should not need to change these.)
INPUT_SETS = ['train', 'test']
OUTPUT_SETS = ['train', 'dev', 'test']

# Paths derived from KALDI_ROOT (you should not need to change these):
#   TIMIT_PATH  - the Kaldi TIMIT recipe directory
#   SOURCE_DIR  - the data directory generated by the Kaldi script
#   OUTPUT_PATH - where the .ark / .scp files are written
TIMIT_PATH = os.path.join(KALDI_ROOT, 'egs/timit/s5/')
SOURCE_DIR = os.path.join(TIMIT_PATH, 'data-kaldi-mel')
OUTPUT_PATH = os.path.join(TIMIT_PATH, 'timit_mel160_arked')


########
# MAIN #
########
def _utt_key(utt_id):
    """Map a preprocessed utterance id to its Kaldi TIMIT utterance key.

    The id is expected to be a '/'-separated path whose last two components
    are the speaker directory and the utterance file, e.g.
    '.../fcjf0/sa1.wav' -> 'FCJF0_SA1'. A trailing '.wav' extension is
    removed case-insensitively.

    Raises:
        IndexError: if the id has fewer than two '/'-separated components.
    """
    path = str(utt_id)
    # Remove only the extension. str.strip('.wav') would instead strip any
    # leading/trailing '.', 'w', 'a', 'v' characters and keep '.WAV' intact.
    if path.lower().endswith('.wav'):
        path = path[:-len('.wav')]
    parts = path.split('/')
    return parts[-2].upper() + '_' + parts[-1].upper()


def _load_preprocessed(input_path, input_sets):
    """Load and concatenate the pickled features and ids of *input_sets*.

    For each set name ``s`` this reads ``<input_path>/<s>_x.pkl`` and
    ``<input_path>/<s>_id.pkl``. Each file is presumably a pickled list
    (anything iterable works with ``list +=``).

    Returns:
        (features, ids): two lists of equal length, aligned by index.

    Raises:
        ValueError: if the feature and id counts diverge after any set.
    """
    features, ids = [], []
    for s in input_sets:
        # NOTE: pickle.load can execute arbitrary code; only load files that
        # were produced locally by 'preprocess_timit.py'.
        with open(os.path.join(input_path, s + '_x.pkl'), 'rb') as fp:
            features += pickle.load(fp)
        with open(os.path.join(input_path, s + '_id.pkl'), 'rb') as fp:
            ids += pickle.load(fp)
        if len(features) != len(ids):
            raise ValueError('Feature/id count mismatch after loading set {!r}: {} vs {}'.format(
                s, len(features), len(ids)))
    return features, ids


def _read_scp_keys(scp_path):
    """Return the keys (first whitespace-separated field of each non-blank line) of a Kaldi .scp file, in file order."""
    keys = []
    with open(scp_path, 'r') as f:
        for line in f:
            fields = line.split()
            if fields:  # tolerate blank lines, e.g. a trailing newline
                keys.append(fields[0])
    return keys


def main():
    """Write the preprocessed TIMIT features as Kaldi .ark/.scp files.

    Steps:
      1. Validate KALDI_ROOT, INPUT_PATH and SOURCE_DIR. If any is missing,
         print a hint to stderr and exit with status 1.
      2. Load every split in INPUT_SETS from INPUT_PATH and key each feature
         matrix by its Kaldi utterance id ('SPEAKER_UTT').
      3. For each split in OUTPUT_SETS, select exactly the utterances listed
         in SOURCE_DIR/<split>/feats.scp, in that file's order. Pipe them
         through `copy-feats --compress=true`, which writes
         OUTPUT_PATH/raw_mel_<split>.ark and OUTPUT_PATH/<split>/feats.scp.

    Raises:
        ValueError: if loaded features and ids disagree in count, or a
            feats.scp lists the same key more than once.
        KeyError: if a feats.scp lists an utterance that is absent from the
            preprocessed data.
    """
    # Exit non-zero on misconfiguration so calling scripts see the failure.
    if not os.path.isdir(KALDI_ROOT):
        print('CHANGE THIS TO YOUR OWN KALDI ROOT: ', KALDI_ROOT, file=sys.stderr)
        sys.exit(1)

    if not os.path.isdir(INPUT_PATH):
        print('Invalid path for the preprocessed timit dataset: ', INPUT_PATH, file=sys.stderr)
        print('Please run \'preprocess_timit.py\' first!', file=sys.stderr)
        sys.exit(1)

    if not os.path.isdir(SOURCE_DIR):
        print('Invalid path for the source directory: ', SOURCE_DIR, file=sys.stderr)
        print('Please read the Wiki page for instructions!', file=sys.stderr)
        sys.exit(1)

    # makedirs also creates missing parents and tolerates an existing directory.
    os.makedirs(OUTPUT_PATH, exist_ok=True)

    # Read train and test from the preprocessed directory.
    features, ids = _load_preprocessed(INPUT_PATH, INPUT_SETS)
    print('[TIMIT-to-ARK] - ', 'Total Dataset len:', len(features))

    # Map every utterance to its Kaldi key. A later duplicate id overwrites an earlier one.
    all_inputs = {_utt_key(utt_id): np.asarray(feat) for utt_id, feat in zip(ids, features)}

    # Keep only the utterances Kaldi lists for each output split.
    for s in OUTPUT_SETS:
        os.makedirs(os.path.join(OUTPUT_PATH, s), exist_ok=True)

        scp_keys = _read_scp_keys(os.path.join(SOURCE_DIR, s, 'feats.scp'))
        missing = [key for key in scp_keys if key not in all_inputs]
        if missing:
            raise KeyError('{} utterance(s) listed in {}/feats.scp are missing from the '
                           'preprocessed data, e.g. {}'.format(len(missing), s, missing[:5]))
        # Dict insertion order keeps the feats.scp order.
        partial_outputs = {key: all_inputs[key] for key in scp_keys}
        if len(partial_outputs) != len(scp_keys):
            raise ValueError('Duplicate utterance keys in {}/feats.scp'.format(s))

        # kaldi_io pipes the matrices into copy-feats, which writes the compressed
        # ark plus a matching scp. copy-feats presumably must be on PATH (e.g. via
        # Kaldi's path.sh) — confirm in your environment.
        ark_scp_output = 'ark:| copy-feats --compress=true ark:- ark,scp:{}/raw_mel_{}.ark,{}/{}/feats.scp'.format(OUTPUT_PATH, str(s), OUTPUT_PATH, str(s))
        with kaldi_io.open_or_fd(ark_scp_output, 'wb') as f:
            for key, mat in tqdm(partial_outputs.items()):
                kaldi_io.write_mat(f, mat, key=key)

    print('[TIMIT-to-ARK] - All done, saved at \'' + str(OUTPUT_PATH) + '\' exit.')

# Run the conversion only when executed as a script, never on import.
if __name__ == '__main__':
    main()
    
